#include "netlib/net/socket/connector.h"

#include <sys/socket.h>

#include <cassert>
#include <cerrno>
#include <functional>
#include <future>

#include "netlib/base/logger.h"
#include "netlib/net/event/channel.h"
#include "netlib/net/event/event_loop.h"
#include "netlib/net/socket/socket_ops.h"

netlib::net::Connector::Connector(EventLoop* loop, const InetAddress& addr) :
    loop_(loop), server_addr_(addr) {}

// The connect channel must already have been deleted when the connector dies.
// NOTE(review): RemoveChannel() queues DeleteChannel() with a captured `this`;
// destroying the Connector before that queued functor runs would leave it
// dangling — this assert catches that case in debug builds.
netlib::net::Connector::~Connector() {
	assert(channel_ == nullptr);
}

void netlib::net::Connector::Start() {
	assert(state_ == kDisconnected);
	is_connect_ = true;
	loop_->RunInLoop([this] { StartInLoop(); });
}

// Starts a connection attempt and blocks the calling thread until the loop
// thread publishes a result through connect_promise_; returns that result.
// The future is obtained before Start() so the caller is finished touching
// connect_promise_ before the loop thread may call set_value on it.
// NOTE(review): assuming connect_promise_ is a std::promise, get_future() may
// only be called once, so a second StartUntil() on the same Connector throws
// std::future_error. It also presumably must not be called from the loop
// thread: the result is produced by loop callbacks (HandleWrite), so blocking
// that thread would deadlock.
bool netlib::net::Connector::StartUntil() {
	auto outcome = connect_promise_.get_future();
	Start();
	const bool connected = outcome.get();
	return connected;
}

// Reconnects from the loop thread. Unlike Start(), no particular prior state
// is required: the state is forced back to kDisconnected (it may still read
// kConnected from the previous session) and the retry delay is reset before
// a fresh attempt is made synchronously.
void netlib::net::Connector::Restart() {
	loop_->AssertInLoopThread();
	retry_delay_ms_ = kInitRetryDelayMs;
	is_connect_ = true;
	state_ = kDisconnected;
	StartInLoop();
}

void netlib::net::Connector::Stop() {
	is_connect_ = false;
	loop_->RunInLoop([this] { StopInLoop(); });
}

// void netlib::net::Connector::Retry() {
//   assert(state_ == kDisconnected);
//   if (is_connect_) {
//     LOG_INFO << "Connector::retry - Retry connecting to "
//              << server_addr_.ToIpPort() << " in " << retry_delay_ms_
//              << " milliseconds.";
//     loop_->RunAfter(retry_delay_ms_ / 1000.0, [this] { StartInLoop(); });
//     retry_delay_ms_ = std::min(retry_delay_ms_ * 2, kMaxRetryDelayMs);
//   } else {
//     LOG_DEBUG << "do not connect";
//   }
// }

void netlib::net::Connector::StartInLoop() {
	loop_->AssertInLoopThread();
	assert(state_ == kDisconnected);
	if (!is_connect_) {
		LOG_DEBUG << "do not connect";
		return;
	}
	int sockfd = ::socket(server_addr_.Family(), SOCK_STREAM | SOCK_CLOEXEC, IPPROTO_TCP);
	//  SOCK_STREAM | SOCK_NONBLOCK | SOCK_CLOEXEC, IPPROTO_TCP);
	assert(sockfd > 0);
	int ret = sockets::Connect(sockfd, server_addr_.GetSockaddr());
	int saved_errno = (ret == 0) ? 0 : errno;

	switch (saved_errno) {
	case 0:
	case EINTR:
	case EISCONN:
	case EINPROGRESS:
		Connecting(sockfd);
		break;

	case EAGAIN:
	case EADDRINUSE:
	case EADDRNOTAVAIL:
	case ECONNREFUSED:
	case ENETUNREACH:
		sockets::Close(sockfd);
		LOG_SYSERR << "connect error in Connector::startInLoop " << saved_errno;
		// Retry();
		break;

	case EACCES:
	case EPERM:
	case EAFNOSUPPORT:
	case EALREADY:
	case EBADF:
	case EFAULT:
	case ENOTSOCK:
		LOG_SYSERR << "connect error in Connector::startInLoop " << saved_errno;
		sockets::Close(sockfd);
		break;

	default:
		LOG_SYSERR << "Unexpected error in Connector::startInLoop " << saved_errno;
		sockets::Close(sockfd);
		// connectErrorCallback_();
		break;
	}
}

void netlib::net::Connector::StopInLoop() {
	loop_->AssertInLoopThread();
	if (kConnecting == state_) {
		state_ = kDisconnected;
		int sockfd = RemoveChannel();
		sockets::Close(sockfd);
		// Retry();
	}
}

void netlib::net::Connector::Connecting(int sockfd) {
	assert(!channel_);
	state_ = kConnecting;
	assert(!channel_);
	channel_ = new Channel(loop_, sockfd);
	channel_->SetWriteCallback([this] { HandleWrite(); });
	channel_->SetErrorCallback([this] { HandleError(); });
	channel_->EnableWriting();
}

int netlib::net::Connector::RemoveChannel() {
	assert(channel_);
	channel_->DisableAll();
	channel_->Remove();
	int sockfd = channel_->Fd();
	loop_->QueueInLoop([this] { DeleteChannel(); });
	return sockfd;
}

// Deferred second half of RemoveChannel(): frees the Channel and clears the
// pointer so a new attempt may create a fresh one.
void netlib::net::Connector::DeleteChannel() {
	assert(channel_ != nullptr);
	auto* doomed = channel_;
	channel_ = nullptr;
	delete doomed;
}

void netlib::net::Connector::HandleWrite() {
	loop_->AssertInLoopThread();
	assert(state_ == kConnecting);
	int sockfd = RemoveChannel();
	// TODO(longfar): get socket errno and retry
	if (sockets::IsSelfConnect(sockfd)) {
		LOG_WARN << "Connector::handleWrite - Self connect";
		sockets::Close(sockfd);
		state_ = kDisconnected;
		// Retry();
	} else {
		state_ = kConnected;
		if (is_connect_) {
			callback_(sockfd);
		} else {
			sockets::Close(sockfd);
		}
		connect_promise_.set_value(is_connect_);
	}
}

void netlib::net::Connector::HandleError() {
	LOG_ERROR << "Connector::handleError state=" << state_;
	state_ = kDisconnected;
	// if (state_ == kConnecting) {
	//   int sockfd = RemoveChannel();
	//   int err = sockets::GetSocketError(sockfd);
	//   LOG_TRACE << "SO_ERROR = " << err << " " << StrerrorTl(err);
	//   sockets::Close(sockfd);
	//   state_ = kDisconnected;
	//   // Retry();
	// }
}
